import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class NMTEncoder(nn.Module):
    """Bidirectional-GRU encoder for the source side of an NMT model.

    Embeds source token indices and runs them through a single-layer
    bidirectional GRU. It uses packed sequences so that padding positions
    are skipped. It returns the per-position outputs and the final hidden
    state of both directions concatenated together.

    Args:
        num_embeddings: Size of the source vocabulary.
        embedding_size: Dimension of each token embedding.
        rnn_hidden_size: Hidden size of each GRU direction.
    """

    def __init__(self, num_embeddings, embedding_size, rnn_hidden_size):
        super().__init__()
        # Index 0 is the padding token: its embedding row is kept at zero
        # and receives no gradient.
        self.source_embedding = nn.Embedding(
            num_embeddings, embedding_size, padding_idx=0
        )
        self.birnn = nn.GRU(
            embedding_size, rnn_hidden_size, bidirectional=True, batch_first=True
        )

    def forward(self, x_source, x_lengths, *, enforce_sorted=False):
        """Encode a right-padded batch of source sequences.

        Args:
            x_source: Integer tensor of shape ``(batch, seq_len)`` holding
                token indices, padded with 0.
            x_lengths: True length of each sequence in the batch, as a
                tensor (on any device), list, or array of ints. Each length
                must be at least 1 (``pack_padded_sequence`` rejects empty
                sequences).
            enforce_sorted: If True, ``x_lengths`` must be sorted in
                decreasing order (required for ONNX export). If False
                (default), batches may come in any order. PyTorch sorts
                them internally and restores the original order in both
                return values.

        Returns:
            A tuple ``(outputs, hidden)``:

            - ``outputs``: ``(batch, max(x_lengths), 2 * rnn_hidden_size)``.
              Positions past a sequence's length are zero. The time
              dimension is the longest length in the batch, which can be
              shorter than ``seq_len``.
            - ``hidden``: ``(batch, 2 * rnn_hidden_size)``. This is the
              forward direction's final state followed by the backward
              direction's final state.
        """
        x_embedded = self.source_embedding(x_source)

        # pack_padded_sequence needs lengths as a 1-D int64 tensor on the CPU,
        # regardless of which device the embeddings live on.
        lengths = torch.as_tensor(x_lengths, dtype=torch.int64).cpu()
        x_packed = pack_padded_sequence(
            x_embedded, lengths, batch_first=True, enforce_sorted=enforce_sorted
        )

        # Outputs for every valid step, plus the final hidden state shaped
        # (num_directions, batch, rnn_hidden_size).
        x_birnn_out, x_birnn_h = self.birnn(x_packed)

        # (2, batch, hidden) -> (batch, 2, hidden) -> (batch, 2 * hidden):
        # concatenate the forward and backward final states per example.
        x_birnn_h = x_birnn_h.permute(1, 0, 2).contiguous()
        x_birnn_h = x_birnn_h.view(x_birnn_h.size(0), -1)

        # Unpack back into a zero-padded (batch, max_len, 2 * hidden) tensor.
        x_unpacked, _ = pad_packed_sequence(x_birnn_out, batch_first=True)
        return x_unpacked, x_birnn_h

